import csv
from typing import Iterator

import json
from allennlp.data import Instance

from config import Config
from tools.utils import create_dir_of_file
from .dataset_reader import DatasetReader


class SNLIDatasetReader(DatasetReader):
    """Reader/writer for SNLI-style sentence-pair classification files.

    The primary on-disk format is JSON lines: one JSON object per line with
    the keys ``gold_label``, ``sentence1`` and ``sentence2``.  A three-column
    CSV layout (label, text, text) is still accepted by ``_read`` so files in
    that layout keep loading; each file is read in exactly one format.
    """

    def __init__(self, cf: Config, token_type=None):
        """Forward ``cf`` and ``token_type`` to the base reader, non-lazily."""
        super().__init__(cf, token_type, lazy=False)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per usable record in ``file_path``.

        The format is chosen once per file from its first non-blank line, so
        a JSON-lines file is never additionally parsed as CSV (which used to
        produce bogus instances or an ``IndexError`` before the real records).

        Args:
            file_path: Path of the data file to read.

        Yields:
            Whatever ``self.text_to_instance(sentence, label)`` returns.
        """
        if self._is_jsonl_file(file_path):
            yield from self._iter_jsonl(file_path)
        else:
            yield from self._iter_csv(file_path)

    @staticmethod
    def _is_jsonl_file(file_path: str) -> bool:
        """Return True when the first non-blank line starts a JSON object."""
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    return stripped.startswith('{')
        # Empty file: either reader yields nothing, default to JSON lines.
        return True

    def _iter_jsonl(self, file_path: str) -> Iterator[Instance]:
        """Yield instances from a JSON-lines file, skipping unusable lines."""
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    record = json.loads(line)
                except ValueError:
                    # json.JSONDecodeError subclasses ValueError; malformed
                    # (or blank) lines are skipped on a best-effort basis.
                    continue
                if not isinstance(record, dict):
                    # Valid JSON but not an object (e.g. a bare number).
                    continue

                label = record['gold_label']
                if label == '-':
                    # Records labelled '-' carry no usable gold label.
                    continue

                sentence = f'{record["sentence1"]} [SEP] {record["sentence2"]}'
                yield self.text_to_instance(sentence, label)

    def _iter_csv(self, file_path: str) -> Iterator[Instance]:
        """Yield instances from a CSV file laid out as (label, text, text)."""
        # newline='' lets the csv module handle embedded newlines itself.
        with open(file_path, 'r', encoding='utf-8', newline='') as f:
            for row in csv.reader(f):
                if len(row) < 3:
                    # Blank or truncated row; skipped like malformed JSON.
                    continue
                yield self.text_to_instance(f'{row[1]} {row[2]}', row[0])

    def _write(self, file_path: str, data):
        """Write ``data`` to ``file_path`` as JSON lines readable by ``_read``.

        Args:
            file_path: Destination path; its directory is created first via
                ``create_dir_of_file``.
            data: Iterable of mappings with a ``'label'`` entry and a
                ``'sentence'`` entry of the form ``"<s1> [SEP] <s2>"``.

        Raises:
            IndexError: If a ``'sentence'`` value contains no ``[SEP]``.
        """
        create_dir_of_file(file_path)
        with open(file_path, 'w', encoding='utf-8') as f:
            for d in data:
                # Split on the first separator only so a later '[SEP]' stays
                # inside sentence2 instead of silently truncating it.
                sentences = d['sentence'].split('[SEP]', 1)
                item = {
                    'gold_label': d['label'],
                    # Drop the padding spaces around '[SEP]' so repeated
                    # read/write cycles do not accumulate whitespace.
                    'sentence1': sentences[0].strip(),
                    'sentence2': sentences[1].strip(),
                }
                # One record per line; without the newline every record ends
                # up on a single line that cannot be parsed back.
                f.write(json.dumps(item) + '\n')
